"""
:mod:`ng911ok.lib.validation` -- Support for validation logic

Notes
----------
:synopsis:
:authors: Riley Baird (OK), Emma Baker (OK)
:created: September 03, 2024
:modified: December 16, 2024
"""


# checkFeatureLocations              [seems to be redundant when a proper topology is used]
# checkAddressPointGEOMSAG           [TODO: Add into check_addresses_against_roads]
# checkDirectionality                check_address_range_directionality
# checkRequiredFields                check_feature_class_configuration
# checkLayerList                     check_geodatabase_configuration
# checkRoadESNOK                     check_road_esn
# checkSubmissionNumbers             check_submission_counts
# checkRequiredFieldValues           check_attributes
# launchRangeFinder                  [TODO: Needs MSAG implementation]
# checkESNandMuniAttribute           check_address_point_esn
# checkNextGenAgainstLegacyFields    check_next_gen_against_legacy [TODO: Finish LegacyFieldInfo.compare_columns() when functions have been written to compute LgcyFulSt/LgcyFulAdd; right now, just assumes those fields are correct.]
# checkTopology                      TODO: check_topology [will require Standard license]
# checkFrequency                     check_uniqueness
# checkSpatialReference              check_geodatabase_configuration
# checkGDBDomains                    check_gdb_domains
# checkRCLMATCH                      check_addresses_against_roads
# checkMsagLabelCombo                [not implementing at this time]
# checkCutbacks                      check_road_geometry
# directionOfPoint                   [only used in ESN check; can probably be replaced with numpy/pandas]
# checkValuesAgainstDomain           check_fields_against_domains
# checkMSAGCOspaces                  [Merged into check_attributes]
# checkUniqueIDFrequency             check_unique_id_frequency
# findInvalidGeometry                [merged into check_feature_class_configuration]
# FindOverlaps                       check_address_ranges
# checkParities                      check_parities
# checkUniqueIDFormat                check_unique_id_format
# [Check Road From/To Level]         check_road_level


from abc import abstractmethod
from collections.abc import Callable
from datetime import datetime
from typing import Optional, Literal, Union, TypeVar, Generic, Self, Protocol

import attrs
import pandas as pd
# noinspection PyUnresolvedReferences
from arcgis.features import GeoSeriesAccessor, GeoAccessor

from .misc import FeatureAttributeValue, Severity, GDBErrorCode, FeatureAttributeErrorCode

# Type variable for the error-code type carried by a validation error message.
# It is constrained (not bound): it resolves to exactly ``GDBErrorCode`` or
# exactly ``FeatureAttributeErrorCode``.
T_Code = TypeVar("T_Code", GDBErrorCode, FeatureAttributeErrorCode)


@attrs.define
class _ValidationErrorInfo(Generic[T_Code]):
    """
    Base ``attrs`` class holding the two fields common to the validation error
    message classes in this module.

    The type parameter ``T_Code`` determines which error-code type ``code``
    holds. Because these are the first ``attrs`` fields, subclasses take
    ``severity`` and ``code`` as their first two initializer arguments.
    """
    # Severity assigned to the error
    severity: Severity
    # Error code identifying the kind of error
    code: T_Code


class ValidationErrorMessage(Protocol[T_Code]):
    """
    Structural interface shared by validation error messages.

    An implementation exposes a severity, an error code of type ``T_Code``,
    the message source (``_message``), the message text (``message``), and a
    ``to_series()`` method.
    """
    # Severity assigned to the error
    severity: Severity
    # Error code identifying the kind of error
    code: T_Code
    # Message source: either the text itself or a callable that receives the
    # message instance and returns the text
    _message: str | Callable[[Self], str]
    # Message text
    message: str

    # Implementations return the message's fields as a ``pd.Series``; the
    # *timestamp* argument is supplied by the caller.
    @abstractmethod
    def to_series(self, timestamp: datetime) -> pd.Series: ...


@attrs.frozen
class FeatureAttributeErrorMessage(_ValidationErrorInfo[FeatureAttributeErrorCode], ValidationErrorMessage[FeatureAttributeErrorCode]):
    """
    Immutable record of a validation error concerning one attribute value, or
    a pair of attribute values, of one or two features.

    Instances are normally created through the alternate constructors
    ``one_feature()``, ``from_df()``, ``from_df_two_fields()``, and
    ``from_joined_df()``.

    The initializer's final argument is named ``message`` (``attrs`` strips
    the leading underscore of ``_message``). It may be a string or a callable
    that receives the partially-built instance and returns a string; the
    resolved text is stored in the ``message`` attribute.
    """
    # Layer, NGUID, field, and value of the first (or only) feature involved
    layer1: str
    nguid1: str
    field1: str
    value1: FeatureAttributeValue
    # Same four items for the second feature/field; ``None`` where not applicable
    layer2: Optional[str]
    nguid2: Optional[str]
    field2: Optional[str]
    value2: FeatureAttributeValue
    # Message source as passed in; excluded from equality comparison
    _message: str | Callable[[Self], str] = attrs.field(eq=False)
    # Resolved message text: computed once at construction from ``_message``
    # and not accepted as an initializer argument
    message: str = attrs.field(
        default=attrs.Factory(
            lambda self: self._message(self) if callable(self._message) else self._message,
            takes_self=True
        ),
        init=False
    )

    @classmethod
    def one_feature(cls, severity: Severity, code: FeatureAttributeErrorCode, layer1: str, nguid1: str, field1: str, value1: FeatureAttributeValue, message: str | Callable[[Self], str]) -> Self:
        """
        Creates an error message that involves a single field of a single
        feature. ``layer2``, ``nguid2``, ``field2``, and ``value2`` are all set
        to ``None``.

        :param severity: The severity of the message
        :type severity: Severity
        :param code: The specific error code of the message
        :type code: FeatureAttributeErrorCode
        :param layer1: Name of the layer containing the feature
        :type layer1: str
        :param nguid1: NGUID of the feature
        :type nguid1: str
        :param field1: Name of the field
        :type field1: str
        :param value1: Value of the field
        :type value1: FeatureAttributeValue
        :param message: Message or message-generating function
        :type message: Union[str, Callable[[Self], str]]
        :return: The new error message
        :rtype: Self
        """
        return cls(severity, code, layer1, nguid1, field1, value1, None, None, None, None, message)

    @classmethod
    def from_df(cls, data: pd.DataFrame, validity: pd.DataFrame, severity: Severity, code: FeatureAttributeErrorCode, layer: str, message: str | Callable[[Self], str]) -> list[Self]:
        """
        Uses an input ``pd.DataFrame`` containing the data and another
        ``pd.DataFrame`` of ``bool`` values, whose shape and index correspond
        to those of *data*, indicating whether the values of *data* are valid,
        to handle error generation for many attributes at once. One
        single-feature error (see ``one_feature()``) is produced per cell that
        is marked invalid and holds a non-null value.

        The name used to look up each row's NGUID is taken from
        ``data.attrs["nguid_name"]`` if present, otherwise from the name of
        the index of *data*.

        **Previously, the following applied** (it is no longer enforced):

        The *data* and *validity* arguments **must have an index that is also
        equivalent to a column in the data.** This can be accomplished by
        calling, e.g., ``df.set_index(index_column_name, drop=False)`` before
        passing ``df`` as the *data* or *validity* argument.

        :param data: The attribute data
        :type data: pd.DataFrame
        :param validity: Per-cell validity of *data* (``True`` = valid)
        :type validity: pd.DataFrame
        :param severity: The severity of the messages
        :type severity: Severity
        :param code: The specific error code of the messages
        :type code: FeatureAttributeErrorCode
        :param layer: Name of the layer to report as ``layer1``
        :type layer: str
        :param message: Message or message-generating function to be passed to
            the output
        :type message: Union[str, Callable[[Self], str]]
        :return: Derived feature attribute errors
        :rtype: list[Self]
        """
        if data.empty:
            return []
        nguid = data.attrs.get("nguid_name", data.index.name)
        validity = validity.astype(pd.BooleanDtype())
        # Keep only the invalid cells (all others become missing), reshape to
        # one row per cell ("variable" = column name, "value" = cell value),
        # discard the missing cells, and expose the index as a regular column
        # so that it can be read by name below.
        # NOTE(review): dropna() also discards cells that are flagged invalid
        # but whose own value is null -- confirm that this is intended.
        # NOTE(review): row[nguid] presumably requires that *nguid* match the
        # name given to the index column by reset_index() -- verify callers.
        invalid_cells = data[~validity].melt(ignore_index=False).dropna().reset_index()
        errors = invalid_cells.apply(lambda row: cls.one_feature(severity, code, layer, row[nguid], row.variable, row.value, message), axis=1)
        if errors.empty:
            return []
        return errors.to_list()

    @classmethod
    def from_df_two_fields(cls, data: pd.DataFrame, validity: pd.Series, field1: str, field2: str, severity: Severity, code: FeatureAttributeErrorCode, layer: str, message: str | Callable[[Self], str]) -> list[Self]:
        """
        Similar to ``from_df()``, but generates error messages involving two
        fields in a single feature class. Unlike in ``from_df()``, however,
        *validity* should be a ``Series``, not a ``DataFrame``, and it should
        represent the validity of the two fields by row. For rows where
        *validity* is ``False``, the values in the columns named *field1* and
        *field2* in *data* will be used to generate the error messages.

        The index label of each row is used as ``nguid1``. In the resulting
        instances, ``layer2`` and ``nguid2`` will be ``None``.

        :param data: The attribute data
        :type data: pd.DataFrame
        :param validity: Per-row validity of the field pair (``True`` = valid)
        :type validity: pd.Series
        :param field1: Name of the first column/field
        :type field1: str
        :param field2: Name of the second column/field
        :type field2: str
        :param severity: The severity of the messages
        :type severity: Severity
        :param code: The specific error code of the messages
        :type code: FeatureAttributeErrorCode
        :param layer: Name of the layer to report as ``layer1``
        :type layer: str
        :param message: Message or message-generating function to be passed to
            the output
        :type message: Union[str, Callable[[Self], str]]
        :return: Derived feature attribute errors
        :rtype: list[Self]
        :raises ValueError: If *data* lacks a column named *field1* or *field2*
        """
        if data.empty:
            return []
        if missing_columns := {field1, field2} - set(data.columns):
            raise ValueError(f"Argument for 'data' is missing expected column(s): {', '.join(missing_columns)}")
        # row.name is the row's index label
        errors: pd.Series = data[~validity].apply(lambda row: cls(severity, code, layer, row.name, field1, row[field1], None, None, field2, row[field2], message), axis=1)
        if errors.empty:
            return []
        return errors.to_list()

    @classmethod
    def from_joined_df(cls, data: pd.DataFrame, column_pairs: dict[str, str], use_left_column_names: bool, severity: Severity, code: FeatureAttributeErrorCode, index_layer: str, join_layer: str, join_layer_id_name: str, message: str | Callable[[Self], str]) -> list[Self]:
        """
        Similar to ``from_df()``, but generates error messages for situations that meet the following criteria:

        * Data from two different feature classes ("left FC"/*index_layer* and "right FC"/*join_layer*) are joined in a single data frame
        * Columns are paired in the sense that a value from a column from *left FC* should equal the value in the corresponding column from *right FC*
        * The pairing is given by *column_pairs*: the values of each key (left) column should equal the values of its mapped (right) column
        * The index of *data* is the index/NGUID of *left FC*
        * The index/NGUID of *right FC* is provided as a column (*join_layer_id_name*)

        This method has no *validity* parameter; data validity is computed
        automatically by comparing each left column with its paired right
        column. Mismatches in which the left value is null are not reported.

        Example (``severity``, ``code``, and ``message`` stand for values
        supplied by the caller)::

            left_df = pd.DataFrame({
                "nguid": ["left1", "left2", "left3", "left4"],
                "street": ["PECAN", "WALNUT", "CHERRY", "MAPLE"],
                "streettype": ["STREET", "AVENUE", "DRIVE", "BOULEVARD"],
                "match_key": ["right1", "right2", "right3", "right4"]
            }).set_index("nguid", drop=False)
            right_df = pd.DataFrame({
                "nguid": ["right1", "right2", "right3", "right4"],
                "street": ["PECAN", "WALNUT", "CHERRY", "CHESTNUT"],
                "streettype": ["STREET", "AVE", "DRIVE", "BOULEVARD"]
            }).set_index("nguid", drop=False)
            data = left_df.join(right_df, on="match_key", rsuffix="_right")
            left_columns = ["street", "streettype"]
            right_columns = [f"{col}_right" for col in left_columns]
            errors = FeatureAttributeErrorMessage.from_joined_df(
                data, dict(zip(left_columns, right_columns)), True,
                severity, code, "left_fc", "right_fc", "nguid_right", message
            )
            # Two errors: "street" for left4/right4 (MAPLE vs. CHESTNUT) and
            # "streettype" for left2/right2 (AVENUE vs. AVE)

        :param data: The input data frame
        :type data: pd.DataFrame
        :param column_pairs: Mapping of corresponding column names as {left: right}
        :type column_pairs: dict[str, str]
        :param use_left_column_names: Whether to use the keys in *column_pairs*
            for both ``field1`` and ``field2`` in the output
        :type use_left_column_names: bool
        :param severity: The severity of the messages
        :type severity: Severity
        :param code: The specific error code of the messages
        :type code: FeatureAttributeErrorCode
        :param index_layer: The name of the layer on the left side of the join
            which retains an NGUID as the index of *data*
        :type index_layer: str
        :param join_layer: The name of the layer on the right side of the join
            that produced *data*
        :type join_layer: str
        :param join_layer_id_name: The name of the **column** (not necessarily
            field) containing NGUIDs for *join_layer*
        :type join_layer_id_name: str
        :param message: Message or message-generating function to be passed to
            the output
        :type message: Union[str, Callable[[Self], str]]
        :return: Derived feature attribute errors
        :rtype: list[Self]
        :raises ValueError: If *data* lacks any column named in *column_pairs*
            or the column *join_layer_id_name*
        """
        if data.empty:
            return []
        left_columns: list[str] = [*column_pairs.keys()]
        right_columns: list[str] = [*column_pairs.values()]
        if missing_columns := {*left_columns, *right_columns, join_layer_id_name} - set(data.columns):
            raise ValueError(f"Argument for 'data' is missing expected column(s): {', '.join(missing_columns)}")

        # Start from an all-True mask shaped like *data*. It is constructed
        # directly rather than with DataFrame.applymap(), which is deprecated
        # as of pandas 2.1.
        validity: pd.DataFrame = pd.DataFrame(True, index=data.index, columns=data.columns)
        # Only the left columns can be flagged: a cell is valid when it equals
        # the cell at the same position in the paired right column.
        validity.loc[:, left_columns] = data[left_columns].values == data[right_columns].values

        def _summarize(row: pd.Series):
            # Values are read back from *data* itself rather than taken from
            # the melted row.
            # NOTE(review): these label-based lookups return scalars only if
            # the index label is unique in *data* -- confirm NGUID uniqueness.
            left_nguid = row.name
            right_nguid = data.loc[left_nguid, join_layer_id_name]
            left_variable = row.variable
            left_value = data.loc[left_nguid, left_variable]
            right_variable = column_pairs[row.variable]
            right_value = data.loc[left_nguid, right_variable]
            return pd.Series(
                [left_nguid, left_variable, left_value, right_nguid, right_variable, right_value],
                ["left_nguid", "left_variable", "left_value", "right_nguid", "right_variable", "right_value"]
            )

        # One row per cell; after dropna(), only flagged cells holding a
        # non-null left value remain.
        melted: pd.DataFrame = data[~validity].melt(ignore_index=False)
        error_summary_df: pd.DataFrame = melted.dropna().apply(_summarize, axis=1).reset_index(drop=True)
        errors = error_summary_df.apply(
            lambda row: cls(severity, code, index_layer, row.left_nguid, row.left_variable, row.left_value, join_layer, row.right_nguid, row.left_variable if use_left_column_names else row.right_variable, row.right_value, message),
            axis=1
        )
        if errors.empty:
            return []
        return errors.to_list()

    @property
    def feature_count(self) -> Literal[0, 1, 2]:
        """
        Number of features referenced by this message: ``2`` if both
        ``nguid1`` and ``nguid2`` are truthy, ``1`` if only ``nguid1`` is
        truthy, and ``0`` otherwise.
        """
        if self.nguid1 and self.nguid2:
            return 2
        elif self.nguid1:
            return 1
        else:
            return 0

    def to_series(self, timestamp: datetime) -> pd.Series:
        """
        Returns this message as an ``object``-dtype ``pd.Series``.

        :param timestamp: Value to store under the ``"Timestamp"`` label
        :type timestamp: datetime
        :return: Series labelled ``Timestamp``, ``Severity``, ``Code``,
            ``Layer1``, ``NGUID1``, ``Field1``, ``Value1``, ``Layer2``,
            ``NGUID2``, ``Field2``, ``Value2``, and ``Message``
        :rtype: pd.Series
        """
        return pd.Series({
            "Timestamp": timestamp,
            "Severity": self.severity,
            "Code": self.code,
            "Layer1": self.layer1,
            "NGUID1": self.nguid1,
            "Field1": self.field1,
            "Value1": self.value1,
            "Layer2": self.layer2,
            "NGUID2": self.nguid2,
            "Field2": self.field2,
            "Value2": self.value2,
            "Message": self.message
        }, dtype=object)


@attrs.frozen
class GDBErrorMessage(_ValidationErrorInfo[GDBErrorCode], ValidationErrorMessage[GDBErrorCode]):
    """
    Immutable record of a validation error reported with a ``GDBErrorCode``,
    optionally naming a layer and a field.

    The initializer argument for the message is named ``message`` (``attrs``
    strips the leading underscore of ``_message``). It may be a string, a
    callable that receives the partially-built instance and returns a string,
    or omitted; the resolved text is stored in the ``message`` attribute, which
    is ``None`` if no message was supplied.
    """
    # Name of the layer concerned, if any
    layer: Optional[str] = None
    # Name of the field concerned, if any
    field: Optional[str] = None
    # Message source as passed in; excluded from equality comparison
    _message: str | Callable[[Self], str] | None = attrs.field(default=None, eq=False)
    # Resolved message text: computed once at construction from ``_message``
    # and not accepted as an initializer argument
    message: str = attrs.field(
        default=attrs.Factory(
            lambda self: self._message(self) if callable(self._message) else self._message,
            takes_self=True
        ),
        init=False
    )

    def to_series(self, timestamp: datetime) -> pd.Series:
        """
        Returns this message as an ``object``-dtype ``pd.Series``.

        :param timestamp: Value to store under the ``"Timestamp"`` label
        :type timestamp: datetime
        :return: Series labelled ``Timestamp``, ``Severity``, ``Code``,
            ``Layer``, ``Field``, and ``Message``
        :rtype: pd.Series
        """
        # dtype=object matches FeatureAttributeErrorMessage.to_series() and
        # keeps the dtype from depending on pandas' type inference.
        return pd.Series({
            "Timestamp": timestamp,
            "Severity": self.severity,
            "Code": self.code,
            "Layer": self.layer,
            "Field": self.field,
            "Message": self.message
        }, dtype=object)


# Covariant type variable bound to ``ValidationErrorMessage``, for generic code
# parameterized by a kind of validation error message.
ValidationErrorMessage_co = TypeVar("ValidationErrorMessage_co", bound=ValidationErrorMessage, covariant=True)


# AddressRangeDtype = np.dtype([("addfrom", np.uint32), ("addto", np.uint32), ("parity", np.str_, 1)])


# @attrs.frozen
# class AddressRangeOverlapReport(object):
#     nguid1: Union[str, int]
#     range1: AddressRange
#     nguid2: Union[str, int]
#     range2: AddressRange
#
#     @property
#     @cache
#     def overlap_range(self) -> AddressRange:
#         return self.range1 & self.range2
#
#     @property
#     @cache
#     def is_overlap(self) -> bool:
#         return bool(self.overlap_range)
#
#     __bool__ = is_overlap
